#!/usr/bin/env python3

import argparse
import sys
import unittest
import elftools.elf.elffile
import elftools.elf.enums
import elftools.elf.sections

# Extension name -> QEMU `-cpu` property text that print_qemu_cpu() appends
# when that extension is present.  Extensions without an entry here add no
# property of their own to the QEMU CPU string.
QEMU_EXT_OPTS = {
    # Bit-manipulation extensions.
    "zba": "zba=true",
    "zbb": "zbb=true",
    "zbc": "zbc=true",
    "zbs": "zbs=true",
    # Vector extensions.
    "v": "v=true,vext_spec=v1.0",
    "zve32f": "Zve32f=true",
    "zve64f": "Zve64f=true",
    # Half-precision floating point.
    "zfh": "Zfh=true",
    "zfhmin": "Zfhmin=true",
    # Floating point in integer registers; print_qemu_cpu() also switches
    # off f/d when any of these three is present.
    "zhinx": "zhinx=true",
    "zfinx": "zfinx=true",
    "zdinx": "zdinx=true",
}

# Extensions that print_spike_isa() leaves out of the Spike ISA string:
# the zve* embedded-vector profiles followed by zvl32b ... zvl4096b
# (every power of two from 2**5 up to 2**12).
SPIKE_EXT_NOT_ALLOWED = [
    "zve32x",
    "zve32f",
    "zve64x",
    "zve64f",
    "zve64d",
] + ["zvl{0}b".format(1 << shift) for shift in range(5, 13)]

# Module-wide description of the target CPU.  The values below are empty
# placeholders; parse_elf_file() overwrites all four keys and the print_*
# helpers read them back.
CPU_OPTIONS = dict(
    xlen="",
    vlen="",
    elen="",
    extensions=[],
)

# Single-letter extensions recognised by parse_march(); each character is one
# extension that may be followed by an optional version (see parse_version()).
# NOTE: the name is misspelled ("SUPPORTTED") but is referenced as-is by
# parse_march(), so it is kept unchanged here.
SUPPORTTED_EXTS = "iemafdcbvph"
# First letters that make parse_march() treat what follows as a multi-letter
# extension and hand it to parse_mc_ext().
MC_EXT_PREFIX = "zsx"

def parse_opt(argv):
    """Parse the command line of this script.

    Args:
        argv: Full argument vector in ``sys.argv`` layout, i.e. ``argv[0]``
            is the program name and the options start at ``argv[1]``.

    Returns:
        The ``argparse.Namespace`` with the attributes ``march``,
        ``selftest``, ``elf_file_path``, ``print_xlen``, ``print_vlen``,
        ``print_qemu_cpu``, ``print_spike_isa`` and ``print_spike_varch``.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('-march', '--with-arch', type=str, dest='march')
    parser.add_argument('-selftest', action='store_true')
    parser.add_argument('--elf-file-path', type=str)
    parser.add_argument('--print-xlen', action='store_true', default=False)
    parser.add_argument('--print-vlen', action='store_true', default=False)
    parser.add_argument('--print-qemu-cpu', action='store_true', default=False)
    parser.add_argument('--print-spike-isa', action='store_true', default=False)
    parser.add_argument('--print-spike-varch', action='store_true',
                        default=False)
    # Parse the vector we were given rather than implicitly falling back to
    # sys.argv; argv[0] is the program name and is not an option.
    return parser.parse_args(argv[1:])

def parse_mc_ext(ext_str, idx):
    """Parse one multi-letter extension starting at ``ext_str[idx]``.

    The extension runs up to the next ``_`` after ``idx`` (or to the end of
    the string).  A trailing ``<major>`` or ``<major>p<minor>`` is split off
    as the version; without one the version is reported as 0.0.

    Args:
        ext_str: The extension part of the arch string.
        idx: Index of the first character of the extension name.

    Returns:
        A tuple ``(end_idx, name, major, minor)`` where ``end_idx`` is the
        index just past the extension (the position of the ``_`` separator
        or ``len(ext_str)``) and ``name`` excludes the version suffix.
    """
    sep_idx = ext_str.find('_', idx + 1)
    end_idx = len(ext_str) if sep_idx == -1 else sep_idx

    major = 0
    minor = 0
    name_end = end_idx
    if ext_str[end_idx - 1].isdigit():
        # The extension carries a version: walk backwards over the trailing
        # digits, never moving in front of the extension itself.
        v_idx = end_idx - 1
        while v_idx > idx and ext_str[v_idx].isdigit():
            v_idx -= 1
        major = int(ext_str[v_idx + 1:end_idx])
        name_end = v_idx + 1

        # A 'p' is only the major/minor separator when digits precede it.
        # Otherwise it is the last letter of the name (e.g. "zicbop1"), and
        # treating it as a separator would try to convert an empty major.
        if (ext_str[v_idx] == 'p' and v_idx - 1 > idx
                and ext_str[v_idx - 1].isdigit()):
            minor = major
            major_idx = v_idx - 1
            while major_idx > idx and ext_str[major_idx].isdigit():
                major_idx -= 1
            major = int(ext_str[major_idx + 1:v_idx])
            name_end = major_idx + 1

    return end_idx, ext_str[idx:name_end], major, minor

def parse_version(ext_str, idx):
    """Parse a single-letter extension and its optional version.

    The letter at ``ext_str[idx]`` may be followed by ``<major>`` or
    ``<major>p<minor>``.  When no version is present it defaults to 2.0.

    Args:
        ext_str: The extension part of the arch string.
        idx: Index of the single-letter extension.

    Returns:
        A tuple ``(end_idx, name, major, minor)`` where ``end_idx`` is the
        index of the first character after the extension and its version.
    """
    major = 2
    minor = 0
    strlen = len(ext_str)
    end_idx = idx + 1
    if end_idx < strlen and ext_str[end_idx].isdigit():
        v_idx = end_idx
        while v_idx < strlen and ext_str[v_idx].isdigit():
            v_idx += 1
        major = int(ext_str[idx + 1:v_idx])
        end_idx = v_idx

        # Only consume 'p' as the major/minor separator when a minor number
        # really follows.  This keeps a version at the very end of the string
        # ("i2") from indexing past the end, and leaves a 'p' that is not
        # followed by digits for the caller to handle.
        if (v_idx + 1 < strlen and ext_str[v_idx] == 'p'
                and ext_str[v_idx + 1].isdigit()):
            minor_v_idx = v_idx + 1
            while minor_v_idx < strlen and ext_str[minor_v_idx].isdigit():
                minor_v_idx += 1
            minor = int(ext_str[v_idx + 1:minor_v_idx])
            end_idx = minor_v_idx

    return end_idx, ext_str[idx], major, minor

def parse_march(march):
    """Split an arch string into its extensions.

    Args:
        march: Arch string such as ``rv64gc`` or ``rv32imc_zve32x``.

    Returns:
        A dict mapping each extension name to a ``(major, minor)`` version
        tuple, or ``None`` when the string is shorter than five characters
        or does not start with ``rv64i``, ``rv32i`` or ``rv32e`` (after the
        ``g`` shorthand has been expanded).  In the latter case the
        offending five-character prefix is printed.

    Raises:
        Exception: If a character that starts no known extension is found.
    """
    if len(march) < 5:
        return None

    # 'g' is shorthand for 'imafd'.
    for shorthand, expanded in (("rv64g", "rv64imafd"),
                                ("rv32g", "rv32imafd")):
        march = march.replace(shorthand, expanded)

    base = march[0:5]
    if base not in ['rv64i', 'rv32i', 'rv32e']:
        print (base)
        return None

    # Drop only "rvXX": the base ISA letter is parsed like any extension.
    ext_str = march[4:]
    total = len(ext_str)
    exts = {}
    pos = 0
    while pos < total:
        letter = ext_str[pos]
        if letter == '_':
            # Separator between extensions, nothing to record.
            pos += 1
            continue

        if letter in SUPPORTTED_EXTS:
            parsed = parse_version(ext_str, pos)
        elif letter in MC_EXT_PREFIX:
            parsed = parse_mc_ext(ext_str, pos)
        else:
            raise Exception("Unrecognized ext : `%s`, %s" %
                            (letter, ext_str))

        pos, ext_name, major, minor = parsed
        exts[ext_name] = (major, minor)
    return exts

def get_vlen(ext_dict):
    """Return the largest vector length implied by the given extensions.

    ``v`` counts as 128, ``zvl<N>b`` as N, and any ``zve...`` extension as
    the number between its ``zve`` prefix and its last character.

    Args:
        ext_dict: Dict keyed by extension name (values are not used).

    Returns:
        The maximum of those lengths, or 0 when no such extension is present.
    """
    candidates = [0]
    for name in ext_dict.keys():
        if name == 'v':
            candidates.append(128)
            continue

        sized_zvl = name.startswith('zvl') and name[-1] == 'b'
        if sized_zvl or name.startswith("zve"):
            # Both spellings carry the number between the three-letter
            # prefix and the final letter, e.g. zvl128b or zve32x.
            candidates.append(int(name[3:-1]))
    return max(candidates)

def get_elen(ext_dict, xlen):
    """Return the element width to use for the given extensions.

    Args:
        ext_dict: Container of extension names (only membership is tested).
        xlen: Fallback value used when no ``zve*`` extension is present.

    Returns:
        64 if any of zve64x/zve64f/zve64d is present, otherwise 32 if
        zve32x or zve32f is present, otherwise ``xlen``.
    """
    # The 64-bit profiles take precedence over the 32-bit ones.
    if any(name in ext_dict for name in ("zve64x", "zve64f", "zve64d")):
        return 64
    if any(name in ext_dict for name in ("zve32x", "zve32f")):
        return 32
    return xlen

def print_qemu_cpu():
    """Build the QEMU CPU option string from ``CPU_OPTIONS``.

    Despite the name nothing is printed; the string is returned.

    Returns:
        A comma-separated string: ``rv<xlen>``, then ``vlen=<vlen>`` when
        vlen is truthy, then the ``QEMU_EXT_OPTS`` entry of every known
        extension in order, and finally ``f=false,d=false`` when zhinx,
        zfinx or zdinx is among the extensions.
    """
    parts = ["rv{0}".format(CPU_OPTIONS['xlen'])]

    vlen = CPU_OPTIONS['vlen']
    if vlen:
        parts.append("vlen={0}".format(vlen))

    extensions = CPU_OPTIONS['extensions']
    parts.extend(QEMU_EXT_OPTS[ext] for ext in extensions
                 if ext in QEMU_EXT_OPTS)

    # The *inx extensions are emitted together with f and d switched off.
    if any(ext in ('zhinx', 'zfinx', 'zdinx') for ext in extensions):
        parts += ["f=false", "d=false"]

    return ",".join(parts)

def print_spike_isa():
    """Build the Spike ISA string from ``CPU_OPTIONS``.

    Despite the name nothing is printed; the string is returned.

    Returns:
        ``rv<xlen>`` followed by every extension not listed in
        ``SPIKE_EXT_NOT_ALLOWED``; names longer than one character are
        prefixed with ``_``.
    """
    pieces = ["rv{0}".format(CPU_OPTIONS['xlen'])]

    for ext in CPU_OPTIONS['extensions']:
        if ext in SPIKE_EXT_NOT_ALLOWED:
            continue
        # Single letters are appended directly, longer names get a separator.
        pieces.append(ext if len(ext) <= 1 else "_{0}".format(ext))

    return "".join(pieces)

def print_spike_varch():
    """Build the ``vlen:<vlen>,elen:<elen>`` string from ``CPU_OPTIONS``.

    Despite the name nothing is printed; the string is returned.

    Returns:
        The formatted string, or ``""`` when ``CPU_OPTIONS['vlen']`` is falsy.
    """
    vlen = CPU_OPTIONS['vlen']
    if vlen:
        return "vlen:{0},elen:{1}".format(vlen, CPU_OPTIONS['elen'])
    return ""

class TestArchStringParse(unittest.TestCase):
    """Self-test for parse_march() and get_vlen()."""

    def _test(self, arch, expected_arch_list, expected_vlen=0):
        """Parse ``arch`` and check its vector length and extension names."""
        parsed = parse_march(arch)
        self.assertEqual(expected_vlen, get_vlen(parsed))
        self.assertEqual(set(expected_arch_list), set(parsed.keys()))

    def test_rv64gc(self):
        # (arch string, expected extension names, expected vlen)
        cases = [
            ("rv64gc",
             ['i', 'm', 'a', 'f', 'd', 'c'],
             0),
            ("rv32imc_zve32x",
             ['i', 'm', 'c', 'zve32x'],
             32),
            ("rv32imc_zve32x_zvl128b",
             ['i', 'm', 'c', 'zve32x', 'zvl128b'],
             128),
        ]
        for arch, names, vlen in cases:
            self._test(arch, names, expected_vlen=vlen)


def selftest():
    """Run the unit tests of this module through ``unittest.main``."""
    # Hand unittest the command line without the script name.
    unittest_argv = sys.argv[1:]
    unittest.main(argv=unittest_argv)

def open_elf(path):
    """Open ``path`` and wrap it in a pyelftools ``ELFFile``.

    On success the underlying file object is intentionally left open: the
    returned ``ELFFile`` keeps using it (it is reachable as
    ``elffile.stream``), so closing it is the caller's responsibility.

    Args:
        path: Path of the ELF file.

    Returns:
        The ``ELFFile`` object for ``path``.

    Raises:
        Exception: If the file is not a valid ELF file.
        OSError: If the file cannot be opened.
    """
    stream = open(path, 'rb')
    try:
        elffile = elftools.elf.elffile.ELFFile(stream)
    except elftools.common.exceptions.ELFError as err:
        # Do not leak the handle when the file turns out not to be ELF.
        stream.close()
        raise Exception("%s is not ELF file!" % path) from err
    except Exception:
        # Any other failure: release the handle and let the error propagate.
        stream.close()
        raise
    return elffile

def get_xlen(path):
    """Return the ELF class of the file at ``path``.

    Args:
        path: Path of the ELF file.

    Returns:
        The ``elfclass`` attribute of the parsed ELF file.

    Raises:
        Exception: If the file is not a valid ELF file (from open_elf()).
    """
    elffile = open_elf(path)
    try:
        return elffile.elfclass
    finally:
        # open_elf() leaves the file open for us; release it once done.
        elffile.stream.close()

def read_arch_attr (path):
    """Read the arch string from the ``.riscv.attributes`` section.

    Every string-valued attribute of the section is inspected and the first
    one containing ``rv32`` or ``rv64`` is returned.

    Args:
        path: Path of the ELF file.

    Returns:
        The arch string stored in the ELF attributes.

    Raises:
        Exception: If the file is not a valid ELF file, has no
            ``.riscv.attributes`` section, or no attribute contains an
            arch string.
    """
    elffile = open_elf(path)
    try:
        attr_sec = elffile.get_section_by_name(".riscv.attributes")
        if attr_sec:
            # pyelftools has support RISC-V attribute but not contain in any
            # release yet, so use ARMAttributesSection for now...
            xattr_section = \
                elftools.elf.sections.ARMAttributesSection (
                    attr_sec.header,
                    attr_sec.name,
                    elffile)
            for subsec in xattr_section.subsections:
                for subsubsec in subsec.subsubsections:
                    for attr in subsubsec.iter_attributes():
                        val = attr.value
                        if (not isinstance(val, str)):
                            continue
                        pos32 = val.find("rv32")
                        pos64 = val.find("rv64")
                        # MAGIC WORKAROUND
                        # Some version of pyelftools has issue for parsing
                        # Tag number = 5, it will wrongly parse it become
                        # Tag number = 4 + 0x10 + 0x5
                        if (pos32 == 2) or (pos64 == 2):
                            val = val[2:]
                        # End of MAGIC WORKAROUND

                        if (pos32 != -1 or pos64 != -1):
                            return val
    finally:
        # open_elf() leaves the file open for us; the section data is read
        # inside the try block, so the stream can be released on every path.
        elffile.stream.close()
    raise Exception("Not found ELF attribute in %s?" % path)

def parse_elf_file(elf_file_path):
    """Fill ``CPU_OPTIONS`` from the ELF file at ``elf_file_path``.

    The arch string is read from the ELF attributes and parsed, then the
    ``extensions``, ``vlen``, ``elen`` and ``xlen`` entries of the
    module-level ``CPU_OPTIONS`` dict are overwritten.

    Args:
        elf_file_path: Path of the ELF file to inspect.

    Raises:
        Exception: If the file is not ELF, carries no arch attribute, or
            its arch string is rejected by parse_march().
    """
    arch = read_arch_attr(elf_file_path)
    extension_dict = parse_march(arch)
    if extension_dict is None:
        # parse_march() signals an unsupported arch string with None; fail
        # with a clear message instead of an AttributeError further down.
        raise Exception("Unsupported arch string `%s` in %s" %
                        (arch, elf_file_path))

    xlen = get_xlen(elf_file_path)

    CPU_OPTIONS["extensions"] = list(extension_dict)
    CPU_OPTIONS["vlen"] = get_vlen(extension_dict)
    CPU_OPTIONS["elen"] = get_elen(extension_dict, xlen)
    CPU_OPTIONS["xlen"] = xlen

def main(argv):
    """Entry point: parse the options and print the requested values.

    ``-selftest`` runs the unit tests.  Otherwise the ELF file named by
    ``--elf-file-path`` is analysed.  ``--print-xlen`` and ``--print-vlen``
    each print one value and return immediately; the remaining ``--print-*``
    options may be combined and are printed in a fixed order.

    Args:
        argv: Full argument vector in ``sys.argv`` layout.

    Returns:
        0 after the self-test, 2 when no ELF file was given, otherwise
        ``None`` (which ``sys.exit`` treats as success).
    """
    opt = parse_opt(argv)
    if opt.selftest:
        selftest()
        return 0

    if opt.elf_file_path is None:
        # Without this check open(None) fails with an opaque TypeError.
        print("error: --elf-file-path is required", file=sys.stderr)
        return 2

    parse_elf_file(opt.elf_file_path)

    if opt.print_xlen:
        print(CPU_OPTIONS['xlen'])
        return

    if opt.print_vlen:
        print(CPU_OPTIONS['vlen'])
        return

    if opt.print_qemu_cpu:
        print(print_qemu_cpu())

    if opt.print_spike_isa:
        print(print_spike_isa())

    if opt.print_spike_varch:
        print(print_spike_varch())

# Run main() only when executed as a script and use its return value as the
# process exit status.
if __name__ == '__main__':
    exit_status = main(sys.argv)
    sys.exit(exit_status)
